#coding=utf-8

import os
from PIL import Image
import numpy as np
from glob import glob
from torchvision import datasets, transforms

def img2bin(src_path, save_path):
  preprocess = transforms.Compose([
      # transforms.Resize(256, Image.BICUBIC),
      # transforms.CenterCrop(224),
      transforms.Resize(700, Image.BICUBIC),
      transforms.CenterCrop(640),
      transforms.ToTensor(),
      transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
  ])

  i = 0
  in_files = os.listdir(src_path)
  for file in in_files:
      i = i + 1
      print(file, "===",i)
      files = os.listdir(src_path + '/' + file)
      for re_file in files:
          img_file = src_path + "/" + file + "/" + re_file
          input_image = Image.open(img_file).convert('RGB')
          input_tensor = preprocess(input_image)
          img = np.array(input_tensor).astype(np.float32)
          img.tofile(os.path.join(save_path, re_file.split('.')[0] + ".bin"))

def bin2info(bin_dir, info_data, width, height):
    bin_images = glob(os.path.join(bin_dir, '*.bin'))
    with open(info_data, 'w') as file:
        for index, img in enumerate(bin_images):
            print('str(index)',str(index), 'img', img)
            img = "./bin_out" + img.split("bin_out")[1]
            content = ' '.join([str(index), img, str(width), str(height)])
            file.write(content)
            file.write('\n')

if __name__ == "__main__":
    #src_path = "/home/DSFD/FaceDetection-DSFD_NPU/img/"
    src_path = "/home/datasets/WIDERFace/WIDER_val/images/"
    bin_path = "/home/DSFD/test/bin_out/"
    info_path = "/home/DSFD/test/info_out/info_result.info"
    img2bin(src_path, bin_path)
    bin2info(bin_path, info_path,640,640)